# -*- coding: utf-8 -*- #
"""
Time                2022/10/11 16:13
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                executor.py
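Description:        CEC-based alert listener: matches incoming SAD alerts against
                    configured push rules and pushes them to the matching targets
                    on a background asyncio worker thread.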
"""
import time
import asyncio
from queue import Queue
from typing import Callable, Optional
from cec_base.event import Event
from cec_base.consumer import Consumer
from cec_base.cec_client import MultiConsumer, CecAsyncConsumeTask, StoppableThread
from clogger import logger
from importlib import import_module
from conf.settings import *
from sysom_utils import CecTarget


class AsyncMultiConsumer(MultiConsumer):
    """MultiConsumer that can also run awaitables on a background event loop.

    Awaitables handed to :meth:`add_async_task` are buffered in a bounded
    queue. A single daemon thread, started by :meth:`start`, drains that
    queue and runs the awaitables on its own private asyncio event loop.
    """

    def __init__(
        self,
        url: str,
        sync_mode: bool = False,
        custom_callback: Optional[Callable[[Event, CecAsyncConsumeTask], None]] = None,
        **kwargs,
    ):
        super().__init__(url, sync_mode, custom_callback, **kwargs)

        # Single background thread that runs queued awaitables (created in start()).
        self._task_process_thread: Optional[StoppableThread] = None
        # Hand-off buffer between consumer callbacks and the worker thread;
        # put() blocks once 1000 items are waiting.
        self._task_queue: Queue = Queue(maxsize=1000)

    def add_async_task(self, task) -> None:
        """Queue an awaitable to be run by the background worker thread.

        Args:
            task: A coroutine object, or another awaitable/future that is not
                bound to a different event loop. Coroutines are wrapped into
                tasks on the worker's own loop, so the caller does not need a
                running event loop.

        Blocks while the queue is full.
        """
        self._task_queue.put(task)

    def _drain_task_queue(self, loop: asyncio.AbstractEventLoop) -> list:
        """Schedule every currently queued awaitable on *loop*; return the tasks."""
        scheduled = []
        # Only the worker thread consumes the queue, so get_nowait() cannot
        # race with another consumer after the empty() check.
        while not self._task_queue.empty():
            item = self._task_queue.get_nowait()
            if not item:
                # A falsy entry (e.g. a None sentinel) ends this draining round.
                break
            try:
                # ensure_future binds coroutines to *loop* as Tasks;
                # asyncio.wait() rejects bare coroutines on Python 3.11+.
                scheduled.append(asyncio.ensure_future(item, loop=loop))
            except (TypeError, ValueError) as exc:
                # TypeError: not awaitable; ValueError: future of another loop.
                logger.error(f"Drop unschedulable async task {item!r}: {exc}")
        return scheduled

    @staticmethod
    def _report_finished_task(task: asyncio.Future) -> None:
        """Log the outcome of a finished task without raising."""
        if task.cancelled():
            # Task.exception() would raise CancelledError here and kill the worker.
            logger.warning(f"Async task was cancelled: {task!r}")
            return
        exc = task.exception()
        if exc is not None:
            logger.error(str(exc))

    @staticmethod
    def _close_loop(loop: asyncio.AbstractEventLoop, pending: set) -> None:
        """Cancel still-running tasks, let them unwind, then close *loop*."""
        for task in pending:
            task.cancel()
        if pending:
            loop.run_until_complete(asyncio.wait(pending))
        loop.close()

    def _process_task(self):
        """Worker-thread body: run queued awaitables until the thread is stopped.

        Each round schedules newly queued awaitables, waits up to 0.5s for at
        least one task to finish, logs failures, and carries unfinished tasks
        over to the next round. The private event loop is closed on exit.
        """
        # Capture our own thread object once, so a later start() that installs
        # a replacement thread cannot keep this (stopped) worker alive.
        worker = self._task_process_thread
        assert worker is not None
        loop = asyncio.new_event_loop()
        pending: set = set()
        try:
            while not worker.stopped():
                pending.update(self._drain_task_queue(loop))
                if not pending:
                    time.sleep(0.1)
                    continue
                finished, pending = loop.run_until_complete(
                    asyncio.wait(
                        pending, return_when=asyncio.FIRST_COMPLETED, timeout=0.5
                    )
                )
                for task in finished:
                    self._report_finished_task(task)
        finally:
            self._close_loop(loop, pending)

    def start(self):
        """Start consuming and, unless one is already running, the async worker."""
        super().start()
        current = self._task_process_thread
        if current is not None and not current.stopped() and current.is_alive():
            return
        self._task_process_thread = StoppableThread(target=self._process_task)
        # Thread.setDaemon() is deprecated since Python 3.10.
        self._task_process_thread.daemon = True
        self._task_process_thread.start()


class AlertListener(AsyncMultiConsumer):
    """A cec-based alert listener.

    Consumes events from the ``CEC_TOPIC_SYSOM_SAD_ALERT`` topic, matches each
    alert against the configured push rules, and queues a push to every target
    referenced by a matching rule (via :meth:`add_async_task`).

    Push targets and rules are built from ``service_config["push_targets"]`` and
    ``service_config["push_rules"]``. Each maps a type name to ``{name: config}``
    entries, and the instances are stored under ``"<type>.<name>"`` keys.
    """

    def __init__(self) -> None:
        super().__init__(
            YAML_CONFIG.get_cec_url(CecTarget.PRODUCER),
            custom_callback=self.on_receive_event,
        )
        self.append_group_consume_task(
            CEC_TOPIC_SYSOM_SAD_ALERT,
            "sysom_alert_pusher",
            Consumer.generate_consumer_id(),
            ensure_topic_exist=True,
        )

        # Keyed by "<rule_type>.<rule_name>" and "<target_type>.<target_name>".
        self.rules = {}
        self.targets = {}

        # Instantiate every configured push target
        for target_type, targets in service_config.get("push_targets", {}).items():
            target_class = self._get_push_target_class(target_type)
            for target_name, target_config in targets.items():
                self.targets[f"{target_type}.{target_name}"] = target_class(
                    target_config
                )

        # Instantiate every configured push rule
        for rule_type, rules in service_config.get("push_rules", {}).items():
            rule_class = self._get_push_rule_class(rule_type)
            for rule_name, rule_config in rules.items():
                self.rules[f"{rule_type}.{rule_name}"] = rule_class(rule_config)

        # NOTE: the async task queue and worker-thread slot are created by
        # AsyncMultiConsumer.__init__ and are intentionally not re-created here.

    def _get_push_target_class(self, target_type):
        """Return class ``PushTarget<Type>`` from module ``lib.targets.<target_type>``.

        Raises:
            Exception: if the module cannot be imported or lacks the class.
        """
        try:
            return getattr(
                import_module(f"lib.targets.{target_type}"),
                f"PushTarget{target_type.title()}",
            )
        except Exception as e:
            raise Exception(f"No channels available => {str(e)}") from e

    def _get_push_rule_class(self, rule_name):
        """Return class ``PushRule<Name>`` from module ``lib.rules.<rule_name>``.

        ``rule_name`` is the rule *type* key from ``service_config["push_rules"]``.

        Raises:
            Exception: if the module cannot be imported or lacks the class.
        """
        try:
            return getattr(
                import_module(f"lib.rules.{rule_name}"), f"PushRule{rule_name.title()}"
            )
        except Exception as e:
            raise Exception(f"No rules available => {str(e)}") from e

    def _deal_recevied_data(self, data: dict):
        """Process one received alert.

        For every rule whose ``is_match(data)`` is true, queue ``target.push(data)``
        for each target name the rule returns. Unknown target names are logged
        and skipped. A target matched by several rules is pushed once per rule.
        """
        for rule in self.rules.values():
            if not rule.is_match(data):
                continue
            for target_name in rule.get_targets():
                target = self.targets.get(target_name)
                if target is None:
                    logger.warning(f"Target not found, target = {target_name}")
                    continue
                # push() presumably returns a coroutine; it is run on the
                # background worker thread, not here.
                self.add_async_task(target.push(data))

    def on_receive_event(self, event: Event, task: CecAsyncConsumeTask):
        """Handle a single consumed event.

        The event is always acknowledged, even if processing fails. Processing
        errors are logged rather than propagated.
        """
        try:
            event_value = event.value
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(event_value, dict):
                logger.warning(
                    f"Received non-dict event value, type = {type(event_value).__name__}"
                )
            elif task.topic_name == CEC_TOPIC_SYSOM_SAD_ALERT:
                self._deal_recevied_data(event_value)
            else:
                logger.warning(
                    f"Received not expect topic data, topic = {task.topic_name}"
                )
        except Exception as e:
            logger.exception(e)
        finally:
            # Acknowledge the message
            task.ack(event)
